{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow 2 Tutorial: Saving and Serializing Keras Models\n",
    "\n",
    "## 1 Saving Sequential or Functional models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/doit/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"3_layer_mlp\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "digits (InputLayer)          [(None, 784)]             0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 64)                50240     \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 64)                4160      \n",
      "_________________________________________________________________\n",
      "predictions (Dense)          (None, 10)                650       \n",
      "=================================================================\n",
      "Total params: 55,050\n",
      "Trainable params: 55,050\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Train on 60000 samples\n",
      "60000/60000 [==============================] - 2s 28us/sample - loss: 0.3133\n"
     ]
    }
   ],
   "source": [
    "# Build a simple model and train it\n",
    "from __future__ import absolute_import, division, print_function\n",
    "import tensorflow as tf\n",
    "tf.keras.backend.clear_session()\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "inputs = keras.Input(shape=(784,), name='digits')\n",
    "x = layers.Dense(64, activation='relu', name='dense_1')(inputs)\n",
    "x = layers.Dense(64, activation='relu', name='dense_2')(x)\n",
    "outputs = layers.Dense(10, activation='softmax', name='predictions')(x)\n",
    "\n",
    "model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer_mlp')\n",
    "model.summary()\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n",
    "x_test = x_test.reshape(10000, 784).astype('float32') / 255\n",
    "\n",
    "model.compile(loss='sparse_categorical_crossentropy',\n",
    "              optimizer=keras.optimizers.RMSprop())\n",
    "history = model.fit(x_train, y_train,\n",
    "                    batch_size=64,\n",
    "                    epochs=1)\n",
    "\n",
    "predictions = model.predict(x_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
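    "### 1.1 Saving the whole model\n",
    "The entire model can be saved to a single file. The saved file includes:\n",
    "- the model's architecture\n",
    "- the model's weights (learned during training)\n",
    "- the model's training configuration (what was passed to `compile`)\n",
    "- the optimizer and its state (so training can resume where it left off)\n"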
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Save the model\n",
    "model.save('the_save_model.h5')\n",
    "# Load the model back\n",
    "new_model = keras.models.load_model('the_save_model.h5')\n",
    "new_prediction = new_model.predict(x_test)\n",
    "np.testing.assert_allclose(predictions, new_prediction, atol=1e-6)  # predictions match\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 Exporting to SavedModel\n",
    "SavedModel is a standalone serialization format for TensorFlow objects. It can be deployed with TensorFlow Serving and read from languages other than Python."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W1013 16:41:11.934549 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.decay\n",
      "W1013 16:41:11.935179 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.learning_rate\n",
      "W1013 16:41:11.935569 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.momentum\n",
      "W1013 16:41:11.935910 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.rho\n",
      "W1013 16:41:11.936488 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.kernel\n",
      "W1013 16:41:11.937078 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.bias\n",
      "W1013 16:41:11.937702 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.kernel\n",
      "W1013 16:41:11.938094 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.bias\n",
      "W1013 16:41:11.938518 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.kernel\n",
      "W1013 16:41:11.938826 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.bias\n",
      "W1013 16:41:11.939161 140085562951488 util.py:152] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/alpha/guide/checkpoints#loading_mechanics for details.\n",
      "W1013 16:41:11.940948 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer\n",
      "W1013 16:41:11.941336 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.iter\n",
      "W1013 16:41:11.941909 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.decay\n",
      "W1013 16:41:11.944620 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.learning_rate\n",
      "W1013 16:41:11.945564 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.momentum\n",
      "W1013 16:41:11.946080 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.rho\n",
      "W1013 16:41:11.946562 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.kernel\n",
      "W1013 16:41:11.946892 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.bias\n",
      "W1013 16:41:11.947290 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.kernel\n",
      "W1013 16:41:11.947617 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.bias\n",
      "W1013 16:41:11.947977 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.kernel\n",
      "W1013 16:41:11.948273 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.bias\n",
      "W1013 16:41:11.948586 140085562951488 util.py:152] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/alpha/guide/checkpoints#loading_mechanics for details.\n"
     ]
    }
   ],
   "source": [
    "# Export to a TensorFlow SavedModel\n",
    "model.save('save_model', save_format='tf')\n",
    "# Load the model back from the SavedModel\n",
    "new_model = keras.models.load_model('save_model')\n",
    "\n",
    "new_prediction = new_model.predict(x_test)\n",
    "np.testing.assert_allclose(predictions, new_prediction, atol=1e-6)  # predictions match\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
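    "The files created by a SavedModel export contain:\n",
    "- the weights (as a TensorFlow checkpoint)\n",
    "- the network graph (as a SavedModel protobuf)"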
   ]
  },
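  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A SavedModel is written as a directory rather than a single file. As a rough sketch, the cell below lists whatever files the export produced, assuming the `save_model` directory from the cell above exists. `list_tree` is a hypothetical helper written for this tutorial, and the exact file names depend on the TensorFlow version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def list_tree(root):\n",
    "    # Hypothetical helper: sorted relative paths of every file under root\n",
    "    paths = []\n",
    "    for dirpath, _dirnames, filenames in os.walk(root):\n",
    "        for filename in filenames:\n",
    "            full_path = os.path.join(dirpath, filename)\n",
    "            paths.append(os.path.relpath(full_path, root))\n",
    "    return sorted(paths)\n",
    "\n",
    "# Only list the export if the directory written above is present\n",
    "if os.path.isdir('save_model'):\n",
    "    for path in list_tree('save_model'):\n",
    "        print(path)"
   ]
  },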
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.3 Saving only the architecture\n",
    "Only the architecture is saved here, so the exported model does not contain the trained parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the architecture config\n",
    "config = model.get_config()\n",
    "reinitialized_model = keras.Model.from_config(config)\n",
    "\n",
    "# The rebuilt model has freshly initialized weights, so its predictions differ\n",
    "new_prediction = reinitialized_model.predict(x_test)\n",
    "assert abs(np.sum(predictions - new_prediction)) > 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The architecture can also be saved as JSON."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the architecture as a JSON string\n",
    "json_config = model.to_json()\n",
    "reinitialized_model = keras.models.model_from_json(json_config)\n",
    "\n",
    "# Again, the weights are not restored, so the predictions differ\n",
    "new_prediction = reinitialized_model.predict(x_test)\n",
    "assert abs(np.sum(predictions - new_prediction)) > 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.4 Saving only the weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the weights as a list of NumPy arrays\n",
    "weights = model.get_weights()\n",
    "# Assign the weights back to the model\n",
    "model.set_weights(weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Saving the architecture and saving the weights can be combined."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = model.get_config()\n",
    "weights = model.get_weights()\n",
    "\n",
    "new_model = keras.Model.from_config(config)  # a get_config() dict is restored through keras.Model.from_config\n",
    "new_model.set_weights(weights)\n",
    "\n",
    "new_predictions = new_model.predict(x_test)\n",
    "np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.5 Saving the complete model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the architecture and the weights\n",
    "json_config = model.to_json()\n",
    "with open('model_config.json', 'w') as json_file:\n",
    "    json_file.write(json_config)\n",
    "model.save_weights('path_to_my_weights.h5')\n",
    "# Load the architecture and the weights\n",
    "with open('model_config.json') as json_file:\n",
    "    json_config = json_file.read()\n",
    "new_model = keras.models.model_from_json(json_config)\n",
    "new_model.load_weights('path_to_my_weights.h5')\n",
    "\n",
    "new_predictions = new_model.predict(x_test)\n",
    "np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W1013 16:49:21.690537 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer\n",
      "W1013 16:49:21.691003 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.iter\n",
      "W1013 16:49:21.691583 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.decay\n",
      "W1013 16:49:21.691886 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.learning_rate\n",
      "W1013 16:49:21.692354 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.momentum\n",
      "W1013 16:49:21.692843 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer.rho\n",
      "W1013 16:49:21.693265 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.kernel\n",
      "W1013 16:49:21.693663 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.bias\n",
      "W1013 16:49:21.693997 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.kernel\n",
      "W1013 16:49:21.694370 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.bias\n",
      "W1013 16:49:21.694779 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.kernel\n",
      "W1013 16:49:21.695171 140085562951488 util.py:144] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.bias\n",
      "W1013 16:49:21.695466 140085562951488 util.py:152] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/alpha/guide/checkpoints#loading_mechanics for details.\n"
     ]
    }
   ],
   "source": [
    "# Of course, this can also be done in a single step\n",
    "model.save('path_to_my_model.h5')\n",
    "del model\n",
    "model = keras.models.load_model('path_to_my_model.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
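    "### 1.6 Weight file formats\n",
    "`save_weights` writes a Keras HDF5 file when the path ends in `.h5` or `.keras`; otherwise it defaults to the TensorFlow checkpoint format. The format can also be set explicitly with `save_format`."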
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No suffix: TensorFlow checkpoint format\n",
    "model.save_weights('weight_tf_savedmodel')\n",
    "# .h5 suffix: Keras HDF5 format\n",
    "model.save_weights('weight_tf_savedmodel.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_weights('weight_tf_savedmodel_tf', save_format='tf')\n",
    "model.save_weights('weight_tf_savedmodel_h5', save_format='h5')"
   ]
  },
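  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The format rule described above can be mirrored in a few lines of plain Python. `infer_weight_format` is a hypothetical helper written for this tutorial, not a Keras function; it only restates the suffix rule from the text and may not cover every suffix that Keras itself recognizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer_weight_format(filepath, save_format=None):\n",
    "    # An explicit save_format takes priority over the file suffix\n",
    "    if save_format is not None:\n",
    "        return save_format\n",
    "    # Suffix rule as stated in the text: .h5 or .keras means HDF5\n",
    "    if filepath.endswith(('.h5', '.keras')):\n",
    "        return 'h5'\n",
    "    # Anything else falls back to the TensorFlow checkpoint format\n",
    "    return 'tf'\n",
    "\n",
    "for path in ['weight_tf_savedmodel', 'weight_tf_savedmodel.h5']:\n",
    "    print(path, '->', infer_weight_format(path))\n",
    "print(infer_weight_format('weight_tf_savedmodel_h5', save_format='h5'))"
   ]
  },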
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
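    "### 1.7 Saving the weights of a subclassed model\n",
    "The architecture of a subclassed model cannot be saved or serialized, because it is defined in Python code; only its weights can be saved."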
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model\n",
    "class ThreeLayerMLP(keras.Model):\n",
    "  \n",
    "    def __init__(self, name=None):\n",
    "        super(ThreeLayerMLP, self).__init__(name=name)\n",
    "        self.dense_1 = layers.Dense(64, activation='relu', name='dense_1')\n",
    "        self.dense_2 = layers.Dense(64, activation='relu', name='dense_2')\n",
    "        self.pred_layer = layers.Dense(10, activation='softmax', name='predictions')\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x = self.dense_1(inputs)\n",
    "        x = self.dense_2(x)\n",
    "        return self.pred_layer(x)\n",
    "\n",
    "def get_model():\n",
    "    return ThreeLayerMLP(name='3_layer_mlp')\n",
    "\n",
    "model = get_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
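    "First, a subclassed model that has never been used cannot be saved.\n",
    "\n",
    "This is because a subclassed model must be called on some data before its weights are created."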
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples\n",
      "60000/60000 [==============================] - 2s 26us/sample - loss: 0.3128\n"
     ]
    }
   ],
   "source": [
    "# Train the model\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n",
    "x_test = x_test.reshape(10000, 784).astype('float32') / 255\n",
    "\n",
    "model.compile(loss='sparse_categorical_crossentropy',\n",
    "              optimizer=keras.optimizers.RMSprop())\n",
    "history = model.fit(x_train, y_train,\n",
    "                    batch_size=64,\n",
    "                    epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
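    "The recommended way to save a subclassed model is to call `save_weights` to create a TensorFlow-format checkpoint.\n",
    "\n",
    "The checkpoint contains the values of all variables associated with the model:\n",
    "- the layers' weights\n",
    "- the optimizer's state\n",
    "- any variables associated with stateful model metrics\n"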
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the weights\n",
    "model.save_weights('my_model_weights', save_format='tf')\n",
    "\n",
    "# Record the outputs for comparison later\n",
    "predictions = model.predict(x_test)\n",
    "first_batch_loss = model.train_on_batch(x_train[:64], y_train[:64])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
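    "To restore the model, you need access to the code that created the model object.\n",
    "\n",
    "Note that in order to restore the optimizer state and the state of any stateful metrics, the model should be compiled first (with exactly the same arguments as before)."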
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved weights into a freshly built model\n",
    "new_model = get_model()\n",
    "new_model.compile(loss='sparse_categorical_crossentropy',\n",
    "                  optimizer=keras.optimizers.RMSprop())\n",
    "\n",
    "#new_model.train_on_batch(x_train[:1], y_train[:1])\n",
    "\n",
    "new_model.load_weights('my_model_weights')\n",
    "\n",
    "new_predictions = new_model.predict(x_test)\n",
    "np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)\n",
    "\n",
    "\n",
    "new_first_batch_loss = new_model.train_on_batch(x_train[:64], y_train[:64])\n",
    "assert first_batch_loss == new_first_batch_loss"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
